/*
 * This file is part of AdaptiveCpp, an implementation of SYCL and C++ standard
 * parallelism for CPUs and GPUs.
 *
 * Copyright The AdaptiveCpp Contributors
 *
 * AdaptiveCpp is released under the BSD 2-Clause "Simplified" License.
 * See file LICENSE in the project root for full license details.
 */
// SPDX-License-Identifier: BSD-2-Clause

#ifndef ACPP_ALGORITHMS_MERGE_PATH_HPP
#define ACPP_ALGORITHMS_MERGE_PATH_HPP

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>

#include "../binary_search/index_search.hpp"

namespace hipsycl::algorithms::merging {


/// This implements the merge path algorithm, which can be used to decompose a merge
/// into N disjoint, independent merges which can be run in parallel. For details, see
/// Green et al. (2014): Merge Path - A Visually Intuitive Approach to Parallel Merging
/// https://arxiv.org/pdf/1406.2628
///
/// Terminology used in this class: "diagonal d" is the set of all index pairs
/// (i, j) with i + j == d, where i is a position in the first input and j a
/// position in the second input. A merge that starts at the point where the
/// merge path crosses diagonal d produces output element d first.
class merge_path {
public:
  /// Computes the positions in both inputs at which the independent merge
  /// with index \c partition_index begins.
  ///
  /// \param first1, last1 first input range
  /// \param first2, last2 second input range
  /// \param comp comparison invoked as comp(element_of_range1, element_of_range2)
  /// \param partition_index index of the independent merge
  /// \param partition_chunk_size number of output elements per independent merge
  /// \param array1_pos_out receives the begin position in the first range
  /// \param array2_pos_out receives the begin position in the second range
  ///
  /// On return, array1_pos_out + array2_pos_out equals
  /// partition_index * partition_chunk_size, clamped to the combined size of
  /// both inputs.
  template <class ForwardIt1, class ForwardIt2, class Compare, class Size>
  static void
  nth_independent_merge_begin(ForwardIt1 first1, ForwardIt1 last1,
                              ForwardIt2 first2, ForwardIt2 last2, Compare comp,
                              Size partition_index, Size partition_chunk_size,
                              Size &array1_pos_out, Size &array2_pos_out) {

    Size input1_size = static_cast<Size>(std::distance(first1, last1));
    Size input2_size = static_cast<Size>(std::distance(first2, last2));

    // The n-th merge starts where the merge path crosses diagonal
    // n * chunk_size.
    binary_diag_search(first1, last1, first2, last2, comp, input1_size,
                       input2_size, partition_index * partition_chunk_size,
                       array1_pos_out, array2_pos_out);
  }

  /// Returns the number of independent merges needed when every merge
  /// produces at most \c segment_chunk_size output elements, i.e. the combined
  /// input size divided by \c segment_chunk_size, rounded up.
  ///
  /// \c segment_chunk_size must not be 0 since it is used as a divisor.
  template <class ForwardIt1, class ForwardIt2, class Size>
  static constexpr Size
  num_independent_merges(ForwardIt1 first1, ForwardIt1 last1, ForwardIt2 first2,
                 ForwardIt2 last2, Size segment_chunk_size) {
    Size input1_size = static_cast<Size>(std::distance(first1, last1));
    Size input2_size = static_cast<Size>(std::distance(first2, last2));

    auto num_diags = total_num_diags(input1_size, input2_size);

    // Integer ceil-division
    return (num_diags + segment_chunk_size - 1) / segment_chunk_size;
  }

private:
  /// Returns a copy of the element \c idx positions after \c first.
  template<class ForwardIt, class Size>
  static auto load(ForwardIt first, Size idx) {
    std::advance(first, idx);
    return *first;
  }

  /// Assigns \c val to the element \c idx positions after \c first.
  template<class ForwardIt, class T, class Size>
  static void store(ForwardIt first, Size idx, const T& val) {
    std::advance(first, idx);
    *first = val;
  }


  // Total number of left-bottom-top-right diagonals of the AB matrix
  template<class Size>
  static constexpr Size total_num_diags(Size size1, Size size2) {
    // There are size1 + size2 - 1 "real" diags, but we need an additional diagonal 0 before
    // the actual data
    return size1 + size2;
  }

  /// Finds the point (array1_index_out, array2_index_out) at which the merge
  /// path crosses the diagonal \c diag_index.
  ///
  /// \param first1, first2 begin of the two inputs
  /// \param last1, last2 unused; the sizes are passed explicitly
  /// \param comp comparison invoked as comp(element_of_array1, element_of_array2)
  /// \param size1, size2 number of elements in the two inputs
  /// \param diag_index diagonal to search. Values beyond size1 + size2 are
  ///        clamped to size1 + size2.
  /// \param array1_index_out, array2_index_out receive the result. Their sum
  ///        always equals the (clamped) diagonal index.
  template <class ForwardIt1, class ForwardIt2, class Compare, class Size>
  static void
  binary_diag_search(ForwardIt1 first1, ForwardIt1 last1, ForwardIt2 first2,
                     ForwardIt2 last2, Compare comp,
                     Size size1, Size size2,
                     Size diag_index, Size &array1_index_out,
                     Size &array2_index_out) {

    // A diagonal beyond the last one would make diag_length() wrap around
    // (unsigned Size) or turn negative, and the search would then read out
    // of bounds. Treat such requests as "end of both inputs".
    const Size num_diags = total_num_diags(size1, size2);
    if(diag_index > num_diags)
      diag_index = num_diags;

    const Size dlen = diag_length(size1, size2, diag_index);

    // Position 0 is the point on the diagonal that consumes as many elements
    // of array1 as possible. If dlen == 0 (diagonal 0, the last diagonal, or
    // one of the inputs is empty), this is the only point on the diagonal and
    // there is nothing to search.
    Size idx = 0;

    if(dlen > 0) {
      // The idea behind the merge path algorithm is to create the merge matrix, where the
      // [i][j] entries are 1 exactly if comp(first1[i],first2[j]) == true, and 0
      // otherwise. This matrix will have a contiguous region of zeroes at the
      // top, the rest will be 1. We can then find the merge path by finding the
      // highest value where the cross-diagonals in the matrix are 1 using binary
      // search. Since we only ever care about the merge matrix when binary
      // searching on the diagonal, this function generates entries from the merge
      // matrix on-the-fly with just one parameter: the current position on the
      // diagonal.
      auto data_loader = [&](auto pos) {
        auto idx1 = array1_idx_from_diag(size1, size2, diag_index, pos);
        auto idx2 = array2_idx_from_diag(size1, size2, diag_index, pos);

        // The merge matrix is shifted by -1 in the first dimension: for a
        // candidate split (idx1, idx2) we compare the *last* element that
        // would be taken from array1, i.e. first1[idx1 - 1], against the
        // *next* element of array2, i.e. first2[idx2].
        // For all positions in [0, dlen) idx1 is at least 1 and idx2 is less
        // than size2; the idx1 == 0 case is only guarded defensively.
        auto v1 = load(first1, idx1 == 0 ? 0 : idx1 - 1);
        auto v2 = load(first2, idx2);

        bool res = comp(v1, v2);
        return static_cast<int>(res);
      };

      auto compare = [&](int v1, int v2) {
        // Note: Do NOT use comp() here, since this is used to compare entries
        // in the merge matrix (which can only be 1 or 0 as generated by data_loader),
        // not used to compare elements of the user data array!
        return v1 < v2;
      };

      // Run binary search across the index space [0, dlen) to find the first 1
      // from top to bottom on the current diagonal. This also has to happen
      // for dlen == 1: such a diagonal still has two candidate points
      // (positions 0 and 1).
      // NOTE(review): this assumes index_upper_bound returns dlen if no entry
      // is 1 - confirm against index_search.hpp.
      // NOTE(review): entries are 0 for equivalent elements, so on ties the
      // element from array2 ends up before the split point - confirm that the
      // merge consuming these positions breaks ties the same way.
      idx = binary_searching::index_upper_bound(Size{0}, dlen, 0,
                                                data_loader, compare);
    }

    array1_index_out = array1_idx_from_diag(size1, size2, diag_index, idx);
    array2_index_out = array2_idx_from_diag(size1, size2, diag_index, idx);
  }

  /// Returns the number of searchable positions on diagonal \c diag_idx.
  /// A diagonal with length n has the candidate split points 0, ..., n
  /// (i.e. one more than its length).
  ///
  /// \c diag_idx must not exceed total_num_diags(size1, size2).
  template <class Size>
  static constexpr Size diag_length(Size size1, Size size2, Size diag_idx) {
    // Diagonal does not yet touch the far border of either array
    if(diag_idx < size1 && diag_idx < size2)
      return diag_idx;

    auto min_size = std::min(size1, size2);
    auto max_size = std::max(size1, size2);

    // Diagonal is limited by the smaller array
    if(diag_idx >= min_size && diag_idx <= max_size)
      return min_size;

    // Diagonal is beyond both sizes and shrinks towards the last diagonal
    return total_num_diags(size1, size2) - diag_idx;
  }

  // position on diag is incremented from the top of the matrix to the bottom.
  /// Returns the array1 index belonging to \c position_on_diag on diagonal
  /// \c diag_idx. Position 0 maps to min(diag_idx, size1); every further
  /// position moves one element back in array1. \c size2 is unused.
  template <class Size>
  static constexpr Size array1_idx_from_diag(Size size1, Size size2,
                                             Size diag_idx,
                                             Size position_on_diag) {
    // Note: We need to use size and *not* size-1 in this expression.
    // The position must be able to become invalid so that we can express when
    // we only need elements from array2 for the merge and we have left array1.
    auto diag_start = std::min(diag_idx, size1);
    return diag_start - position_on_diag;
  }

  /// Returns the array2 index belonging to \c position_on_diag on diagonal
  /// \c diag_idx, such that it adds up to \c diag_idx together with the
  /// result of array1_idx_from_diag() for the same arguments.
  /// \c size2 is unused.
  template <class Size>
  static constexpr Size array2_idx_from_diag(Size size1, Size size2,
                                             Size diag_idx,
                                             Size position_on_diag) {
    if(diag_idx <= size1)
      return position_on_diag;
    // array1 can contribute at most size1 elements, the remainder of the
    // diagonal index has to come from array2.
    return diag_idx - size1 + position_on_diag;
  }
};




}



#endif
